#!/usr/bin/env python3
import asyncio
import argparse
import json
import os
import re
import time
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple, Any, Union
from datetime import datetime
import aiohttp  

# ====================== Configuration ======================
@dataclass
class Config:
    """Static settings for a run: dataset location and LLM endpoint."""
    # Path of the JSON dataset file that the solver opens and parses.
    dataset_path: str = "StrategyQA.json"
    # Model identifier sent as "model" in the request payload (placeholder value).
    model: str = "your model"  
    # URL prefix to which "/chat/completions" is appended (placeholder value).
    base_url: str = "your base_url"  

# ====================== LLM Client ======================
class LLMClient:
    """Wrapper for LLM API using aiohttp.

    Posts single-turn chat requests to ``{base_url}/chat/completions`` and
    keeps running token totals in ``token_counts`` as ``[input, output]``.
    """

    def __init__(self, config: Config):
        self.config = config
        # Running totals across every call: [input_tokens, output_tokens].
        self.token_counts = [0, 0]

    async def generate(self, prompt: str) -> str:
        """Send ``prompt`` as one user message and return the reply text.

        Args:
            prompt: Full prompt text, sent as the only message of the request.

        Returns:
            The ``content`` of the first choice's message. An empty string is
            returned when the server sends a null ``content``.

        Raises:
            aiohttp.ClientResponseError: The server answered with an HTTP
                error status (>= 400).
            Exception: Any other transport or parsing error; it is printed
                and re-raised unchanged.
        """
        payload = {
            "model": self.config.model,
            "messages": [{"role": "user", "content": prompt}],
            "temperature": 0.3,
            "max_tokens": 8000,
            "top_p": 0.8
        }

        try:
            async with aiohttp.ClientSession() as session:
                async with session.post(
                    f"{self.config.base_url}/chat/completions",
                    json=payload,
                    timeout=aiohttp.ClientTimeout(total=120)
                ) as response:
                    # Surface HTTP errors as such; without this an error body
                    # only shows up later as an opaque KeyError on "choices".
                    response.raise_for_status()
                    resp = await response.json()

                    # A null content would otherwise crash len() below.
                    content = resp["choices"][0]["message"]["content"] or ""

                    # Prefer the token usage reported by the server; fall back
                    # to the rough 4-characters-per-token estimate otherwise.
                    usage = resp.get("usage")
                    if not isinstance(usage, dict):
                        usage = {}
                    input_tokens = usage.get("prompt_tokens")
                    if not isinstance(input_tokens, int):
                        input_tokens = len(prompt) // 4
                    output_tokens = usage.get("completion_tokens")
                    if not isinstance(output_tokens, int):
                        output_tokens = len(content) // 4

                    self.token_counts[0] += input_tokens
                    self.token_counts[1] += output_tokens

                    return content
        except Exception as e:
            print(f"LLM Error: {str(e)}")
            raise

# ====================== Core Solver ======================
class StrategyQASolver:
    """Chain-of-Thought StrategyQA Solver.

    Builds a few-shot prompt per problem, asks the LLM, extracts a true/false
    verdict from the reply and keeps running accuracy/token statistics in
    ``stats``.
    """

    # Explicit verdict statements such as "The correct answer is false",
    # "Final Answer: true", "Answer: [true]" or "**Answer**: true".
    _EXPLICIT_ANSWER_RE = re.compile(
        r'answer[\s*_]*(?:is|:)[\s:*_\[({"`]*(true|false)\b',
        re.IGNORECASE,
    )
    # Any standalone true/false word (not part of a longer word).
    _STANDALONE_ANSWER_RE = re.compile(r'\b(true|false)\b', re.IGNORECASE)

    def __init__(self):
        self.config = Config()
        self.llm = LLMClient(self.config)
        self.stats = {
            "total_problems": 0,
            "correct_answers": 0,
            "accuracy": 0.0,
            "tokens_used": [0, 0]
        }

    def _extract_answer(self, text: str) -> Optional[bool]:
        """Extract the true/false verdict from a model response.

        The last explicit answer statement wins ("The correct answer is X",
        "Final Answer: X", "Answer: X"). If there is none, the last standalone
        "true"/"false" word is used. Taking the last occurrence matters because
        chain-of-thought replies mention both words while reasoning and state
        the verdict at the end.

        Args:
            text: Raw response text from the LLM.

        Returns:
            True or False for the extracted verdict, or None when the text
            is empty or contains no true/false word.
        """
        if not text:
            return None

        for pattern in (self._EXPLICIT_ANSWER_RE, self._STANDALONE_ANSWER_RE):
            found = pattern.findall(text)
            if found:
                return found[-1].lower() == 'true'

        return None

    def _verify_answer(self, problem: Dict[str, Any], selected_answer: bool) -> bool:
        """Return whether ``selected_answer`` equals the problem's "answer".

        A problem without an "answer" key is treated as having answer False.
        """
        correct_answer = problem.get("answer", False)
        return selected_answer == correct_answer

    async def load_problems(self, start_idx: int, end_idx: int) -> List[Dict]:
        """Load the slice ``[start_idx:end_idx]`` of the JSON dataset.

        Returns:
            The selected problems, or an empty list when the file cannot be
            read, parsed or sliced (the error is printed, not raised).
        """
        try:
            with open(self.config.dataset_path, "r", encoding="utf-8") as f:
                data = json.load(f)
            return data[start_idx:end_idx]
        except Exception as e:
            print(f"Error loading dataset: {str(e)}")
            return []

    async def solve_problem(self, problem: Dict[str, Any]) -> Dict[str, Any]:
        """Solve one problem with a chain-of-thought prompt and score it.

        Args:
            problem: Mapping with a "question" key and optional "facts"
                (list of strings) and "answer" keys.

        Returns:
            A dict with the question, facts, reference answer, raw response,
            extracted answer (None when none was found), correctness flag and
            the client's cumulative token counts at this point.

        Raises:
            KeyError: ``problem`` has no "question" key.
            Exception: Whatever the LLM client raises for a failed request.
        """
        question = problem["question"]
        fact_list = problem.get("facts", [])
        facts = "\n".join(f"- {fact}" for fact in fact_list)

        prompt = f""" 
Question: {question}
Supporting Facts:
{facts}

Let's think step by step to solve the question. Analyze the facts carefully and determine if the answer is true or false. 
Give the correct answer by stating "The correct answer is [X]" where [X] is exactly either "true" or "false".
for example:

Question: Does highest US Court have enough seats for every Prime Minister of the United Kingdom since 1952?
facts: [
            "The highest court in the US is the Supreme Court.",
            "There are nine seats on the Supreme Court.",
            "There have been fifteen Prime Ministers of the United Kingdom since 1952."
        ]
Answer: The correct answer is false

Question: Was Pi an acceptable number of children in 1980s China?
facts: [
            "Pi, the ratio of a circle's circumference to diameter, is equal to 3.14.",
            "In the 1980's China instituted a one-child policy.",
            "People that violated China's one child policy were fined heavily and some were sterilized."
        ]
Answer: The correct answer is false

Then answer the question: 
"""
        response = await self.llm.generate(prompt)
        answer = self._extract_answer(response)
        # A response without any recognizable verdict counts as wrong.
        is_correct = answer is not None and self._verify_answer(problem, answer)

        # Update statistics
        self.stats["total_problems"] += 1
        if is_correct:
            self.stats["correct_answers"] += 1
        self.stats["accuracy"] = (
            self.stats["correct_answers"] / self.stats["total_problems"] * 100
        )
        self.stats["tokens_used"] = self.llm.token_counts.copy()

        return {
            "question": question,
            "facts": fact_list,
            "correct_answer": problem.get("answer", False),
            "response": response,
            "answer": answer,
            "is_correct": is_correct,
            # Cumulative totals of the client so far, not per-problem counts.
            "tokens_used": self.llm.token_counts.copy()
        }

# ====================== Main Execution ======================
async def main():
    """Run the solver over a dataset slice and write a JSON report.

    Reads ``--start``/``--end`` from the command line, solves each problem in
    that slice sequentially, prints per-problem results and finally writes all
    results plus aggregate stats to ``log/StrategyQA_cot/``.

    If solving fails part-way, the results collected so far are still written
    before the exception propagates.
    """
    parser = argparse.ArgumentParser(description="Chain-of-Thought StrategyQA Problem Solver")
    parser.add_argument("--start", type=int, default=0, help="Start index in dataset")
    parser.add_argument("--end", type=int, default=5, help="End index in dataset")
    args = parser.parse_args()

    os.makedirs("log/StrategyQA_cot", exist_ok=True)
    solver = StrategyQASolver()
    problems = await solver.load_problems(args.start, args.end)
    if not problems:
        print(f"No problems loaded for range {args.start}:{args.end}; nothing to do.")
        return

    all_results = []
    try:
        # Number problems by their dataset index rather than slice position.
        for idx, problem in enumerate(problems, start=args.start):
            print(f"\n{'='*50}\nProcessing problem {idx}: {problem['question'][:50]}...\n{'='*50}")

            result = await solver.solve_problem(problem)
            all_results.append(result)

            # Print results for this problem
            print(f"Question: {result['question'][:100]}...")
            print(f"Selected Answer: {result.get('answer', '?')}")
            print(f"Correct Answer: {result['correct_answer']}")
            print(f"Correct: {result.get('is_correct', False)}")
            print(f"Tokens used: Input={result['tokens_used'][0]}, Output={result['tokens_used'][1]}")
    finally:
        # Save whatever was completed, even when a later problem raised, so a
        # single failed request does not discard the whole run.
        if all_results:
            if len(all_results) < len(problems):
                print(f"\nRun stopped after {len(all_results)} of {len(problems)} problems; saving partial results.")

            # The timestamp keeps repeated runs over the same range from
            # overwriting each other's report.
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            filename = (
                f"log/StrategyQA_cot/results_{args.start}_{args.end}"
                f"_acc{solver.stats['accuracy']:.2f}%_{timestamp}.json"
            )

            with open(filename, "w", encoding="utf-8") as f:
                json.dump({
                    "results": all_results,
                    "stats": solver.stats
                }, f, indent=2, ensure_ascii=False)

            print(f"\n{'='*50}")
            print(f"Results saved to {filename}")
            print(f"Total Problems: {solver.stats['total_problems']}")
            print(f"Correct Answers: {solver.stats['correct_answers']}")
            print(f"Accuracy: {solver.stats['accuracy']:.2f}%")
            print(f"Total Tokens Used: Input={solver.stats['tokens_used'][0]}, Output={solver.stats['tokens_used'][1]}")
            print(f"{'='*50}")

# Script entry point: when executed directly, run the async main() to completion.
if __name__ == "__main__":
    asyncio.run(main())